{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Anchor explanations for income prediction\n",
    "\n",
    "In this example, we explain the predictions of a Random Forest classifier that predicts whether a person earns more or less than $50k, based on characteristics such as age, marital status, gender and occupation. The features are a mixture of ordinal and categorical data and are pre-processed accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
    "from alibi.explainers import AnchorTabular\n",
    "from alibi.datasets import fetch_adult"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load adult dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `fetch_adult` function returns a `Bunch` object containing the features, the targets, the feature names and a category map that links the numeric codes of each categorical feature to its category names. The category map is required for formatting the output of the Anchor explainer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['data', 'target', 'feature_names', 'target_names', 'category_map'])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adult = fetch_adult()\n",
    "adult.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = adult.data\n",
    "target = adult.target\n",
    "feature_names = adult.feature_names\n",
    "category_map = adult.category_map"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that for your own datasets you can use our utility function [gen_category_map](../api/alibi.utils.data.rst) to create the category map:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from alibi.utils import gen_category_map"
   ]
  },
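  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what such a category map contains, the sketch below builds one by hand for a small made-up array. It does not call `gen_category_map`; the `toy_*` names and values are assumptions introduced only for this example. The result maps the index of each categorical column to the list of its category names, which matches the structure of `adult.category_map`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy data: column 0 is numeric, column 1 is categorical.\n",
    "toy_raw = np.array([[39, 'Private'],\n",
    "                    [50, 'Self-emp'],\n",
    "                    [38, 'Private']], dtype=object)\n",
    "\n",
    "# np.unique returns the sorted category names and the integer code of each row.\n",
    "toy_categories, toy_codes = np.unique(toy_raw[:, 1].astype(str), return_inverse=True)\n",
    "\n",
    "# Category map built by hand: column index -> list of category names.\n",
    "toy_category_map = {1: toy_categories.tolist()}\n",
    "\n",
    "# Replace the strings by their integer codes so the array is fully numeric.\n",
    "toy_encoded = np.c_[toy_raw[:, 0].astype(float), toy_codes]\n",
    "print(toy_category_map)\n",
    "print(toy_encoded)"
   ]
  },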
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define shuffled training and test sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(0)\n",
    "data_perm = np.random.permutation(np.c_[data, target])\n",
    "data = data_perm[:,:-1]\n",
    "target = data_perm[:,-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 30000\n",
    "X_train,Y_train = data[:idx,:], target[:idx]\n",
    "X_test, Y_test = data[idx+1:,:], target[idx+1:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create feature transformation pipeline\n",
    "Create the feature pre-processor. It needs to have `fit` and `transform` methods. Different types of pre-processing can be applied to all or part of the features. In the example below, we standardize the ordinal features and apply one-hot encoding to the categorical features.\n",
    "\n",
    "Ordinal features:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "ordinal_features = [x for x in range(len(feature_names)) if x not in list(category_map.keys())]\n",
    "ordinal_transformer = Pipeline(steps=[('imputer', SimpleImputer(strategy='median')),\n",
    "                                      ('scaler', StandardScaler())])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Categorical features:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "categorical_features = list(category_map.keys())\n",
    "categorical_transformer = Pipeline(steps=[('imputer', SimpleImputer(strategy='median')),\n",
    "                                          ('onehot', OneHotEncoder(handle_unknown='ignore'))])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Combine and fit:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ColumnTransformer(n_jobs=None, remainder='drop', sparse_threshold=0.3,\n",
       "                  transformer_weights=None,\n",
       "                  transformers=[('num',\n",
       "                                 Pipeline(memory=None,\n",
       "                                          steps=[('imputer',\n",
       "                                                  SimpleImputer(add_indicator=False,\n",
       "                                                                copy=True,\n",
       "                                                                fill_value=None,\n",
       "                                                                missing_values=nan,\n",
       "                                                                strategy='median',\n",
       "                                                                verbose=0)),\n",
       "                                                 ('scaler',\n",
       "                                                  StandardScaler(copy=True,\n",
       "                                                                 with_mean=True,\n",
       "                                                                 with_std=True))],\n",
       "                                          verbose=False),\n",
       "                                 [0, 8, 9, 10]),\n",
       "                                ('cat',\n",
       "                                 Pipeline(memory=None,\n",
       "                                          steps=[('imputer',\n",
       "                                                  SimpleImputer(add_indicator=False,\n",
       "                                                                copy=True,\n",
       "                                                                fill_value=None,\n",
       "                                                                missing_values=nan,\n",
       "                                                                strategy='median',\n",
       "                                                                verbose=0)),\n",
       "                                                 ('onehot',\n",
       "                                                  OneHotEncoder(categories='auto',\n",
       "                                                                drop=None,\n",
       "                                                                dtype=<class 'numpy.float64'>,\n",
       "                                                                handle_unknown='ignore',\n",
       "                                                                sparse=True))],\n",
       "                                          verbose=False),\n",
       "                                 [1, 2, 3, 4, 5, 6, 7, 11])],\n",
       "                  verbose=False)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preprocessor = ColumnTransformer(transformers=[('num', ordinal_transformer, ordinal_features),\n",
    "                                               ('cat', categorical_transformer, categorical_features)])\n",
    "preprocessor.fit(X_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train Random Forest model\n",
    "\n",
    "Fit the model on the pre-processed (imputed, one-hot encoded and standardized) data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,\n",
       "                       criterion='gini', max_depth=None, max_features='auto',\n",
       "                       max_leaf_nodes=None, max_samples=None,\n",
       "                       min_impurity_decrease=0.0, min_impurity_split=None,\n",
       "                       min_samples_leaf=1, min_samples_split=2,\n",
       "                       min_weight_fraction_leaf=0.0, n_estimators=50,\n",
       "                       n_jobs=None, oob_score=False, random_state=None,\n",
       "                       verbose=0, warm_start=False)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.seed(0)\n",
    "clf = RandomForestClassifier(n_estimators=50)\n",
    "clf.fit(preprocessor.transform(X_train), Y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define predict function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train accuracy:  0.9655333333333334\n",
      "Test accuracy:  0.855859375\n"
     ]
    }
   ],
   "source": [
    "predict_fn = lambda x: clf.predict(preprocessor.transform(x))\n",
    "print('Train accuracy: ', accuracy_score(Y_train, predict_fn(X_train)))\n",
    "print('Test accuracy: ', accuracy_score(Y_test, predict_fn(X_test)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize and fit anchor explainer for tabular data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "explainer = AnchorTabular(predict_fn, feature_names, categorical_names=category_map, seed=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Discretize the ordinal features into quartiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AnchorTabular(meta={\n",
       "    'name': 'AnchorTabular',\n",
       "    'type': ['blackbox'],\n",
       "    'explanations': ['local'],\n",
       "    'params': {'seed': 1, 'disc_perc': [25, 50, 75]}\n",
       "})"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "explainer.fit(X_train, disc_perc=[25, 50, 75])"
   ]
  },
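  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the percentiles `[25, 50, 75]` more concrete, the sketch below bins a made-up array of ages into quartiles with plain NumPy. It only illustrates the idea of quartile binning and is not a reproduction of how the explainer discretizes internally; the `toy_*` names and values are assumptions for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical ages, used only to illustrate quartile binning.\n",
    "toy_ages = np.array([19, 23, 28, 31, 37, 42, 48, 55, 63])\n",
    "\n",
    "# Bin edges at the 25th, 50th and 75th percentiles.\n",
    "toy_edges = np.percentile(toy_ages, [25, 50, 75])\n",
    "\n",
    "# Map each value to one of four bins (0 to 3).\n",
    "# With right=True, a value equal to an edge falls into the lower bin.\n",
    "toy_bins = np.digitize(toy_ages, toy_edges, right=True)\n",
    "print('Edges:', toy_edges)\n",
    "print('Bins: ', toy_bins)"
   ]
  },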
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Getting an anchor\n",
    "\n",
    "Below, we get an anchor for the prediction of the first observation in the test set. An anchor is a sufficient condition - that is, when the anchor holds, the prediction should be the same as the prediction for this instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction:  <=50K\n"
     ]
    }
   ],
   "source": [
    "idx = 0\n",
    "class_names = adult.target_names\n",
    "print('Prediction: ', class_names[explainer.predictor(X_test[idx].reshape(1, -1))[0]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We set the precision threshold to 0.95. This means that predictions on observations where the anchor holds will be the same as the prediction on the explained instance at least 95% of the time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Anchor: Marital Status = Separated AND Sex = Female\n",
      "Precision: 0.95\n",
      "Coverage: 0.18\n"
     ]
    }
   ],
   "source": [
    "explanation = explainer.explain(X_test[idx], threshold=0.95)\n",
    "print('Anchor: %s' % (' AND '.join(explanation.anchor)))\n",
    "print('Precision: %.2f' % explanation.precision)\n",
    "print('Coverage: %.2f' % explanation.coverage)"
   ]
  },
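  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two reported quantities can be sketched on a toy example. Precision is the share of observations satisfying the anchor that receive the same prediction as the explained instance, and coverage is the share of observations to which the anchor applies. The sketch below computes both by filtering a small fixed array, which is a simplification: the explainer estimates them from sampled data. The `toy_*` data, the stand-in model and the candidate anchor are all assumptions for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy data: two categorical codes and one numeric column.\n",
    "toy_obs = np.array([[0, 0, 30], [0, 0, 45], [0, 1, 50], [1, 0, 20],\n",
    "                    [1, 1, 60], [0, 0, 35], [0, 0, 25], [1, 0, 38]])\n",
    "\n",
    "# Hypothetical stand-in for a black-box model: class 1 if the last column exceeds 40.\n",
    "toy_predict = lambda x: (x[:, 2] > 40).astype(int)\n",
    "\n",
    "# Candidate anchor: columns 0 and 1 take the same values as in the explained instance.\n",
    "toy_instance = toy_obs[0]\n",
    "toy_anchor_cols = [0, 1]\n",
    "\n",
    "# Boolean mask of the rows on which the anchor holds.\n",
    "toy_holds = np.all(toy_obs[:, toy_anchor_cols] == toy_instance[toy_anchor_cols], axis=1)\n",
    "\n",
    "# Coverage: share of rows to which the anchor applies.\n",
    "toy_coverage = toy_holds.mean()\n",
    "\n",
    "# Precision: among those rows, share predicted like the explained instance.\n",
    "toy_instance_pred = toy_predict(toy_instance.reshape(1, -1))[0]\n",
    "toy_precision = (toy_predict(toy_obs[toy_holds]) == toy_instance_pred).mean()\n",
    "\n",
    "print('Precision: %.2f' % toy_precision)\n",
    "print('Coverage: %.2f' % toy_coverage)"
   ]
  },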
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ...or not?\n",
    "\n",
    "Let's try getting an anchor for a different observation in the test set - one for which the prediction is `>50K`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction:  >50K\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Could not find an result satisfying the 0.95 precision constraint. Now returning the best non-eligible result.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Anchor: Capital Loss > 0.00 AND Relationship = Husband AND Marital Status = Married AND Age > 37.00 AND Race = White AND Country = United-States AND Sex = Male\n",
      "Precision: 0.71\n",
      "Coverage: 0.05\n"
     ]
    }
   ],
   "source": [
    "idx = 6\n",
    "class_names = adult.target_names\n",
    "print('Prediction: ', class_names[explainer.predictor(X_test[idx].reshape(1, -1))[0]])\n",
    "\n",
    "explanation = explainer.explain(X_test[idx], threshold=0.95)\n",
    "print('Anchor: %s' % (' AND '.join(explanation.anchor)))\n",
    "print('Precision: %.2f' % explanation.precision)\n",
    "print('Coverage: %.2f' % explanation.coverage)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that no anchor satisfying the 0.95 precision threshold is found! The explainer instead returns the best candidate, which reaches a precision of only 0.71.\n",
    "\n",
    "This is due to the imbalanced dataset (roughly 25:75 high:low earner proportion): during the sampling stage, feature ranges corresponding to low earners are oversampled. This behaviour is useful because it can reveal an imbalanced dataset, but it can also be addressed by producing balanced datasets so that anchors can be found for either class."
   ]
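  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to produce a balanced dataset is to downsample the majority class, as sketched below on made-up data. This is only one option and not necessarily the approach intended above; the `toy_*` arrays are assumptions for this example. In this notebook, the same indexing could in principle be applied to `X_train` and `Y_train` before fitting the explainer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical imbalanced labels (25:75) and random features, for illustration only.\n",
    "toy_rng = np.random.default_rng(0)\n",
    "toy_labels = np.array([1] * 25 + [0] * 75)\n",
    "toy_feats = toy_rng.normal(size=(100, 3))\n",
    "\n",
    "# Row indices of each class.\n",
    "toy_pos = np.flatnonzero(toy_labels == 1)\n",
    "toy_neg = np.flatnonzero(toy_labels == 0)\n",
    "\n",
    "# Downsample both classes to the size of the smaller one, without replacement.\n",
    "toy_n = min(len(toy_pos), len(toy_neg))\n",
    "toy_keep = np.concatenate([toy_rng.choice(toy_pos, toy_n, replace=False),\n",
    "                           toy_rng.choice(toy_neg, toy_n, replace=False)])\n",
    "\n",
    "toy_feats_bal, toy_labels_bal = toy_feats[toy_keep], toy_labels[toy_keep]\n",
    "\n",
    "# Class counts after balancing.\n",
    "print(np.bincount(toy_labels_bal))"
   ]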
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
